Skip to content

Conversation

@skc7
Copy link
Contributor

@skc7 skc7 commented Nov 28, 2025

PR adds support of openmp 6.1 feature num_teams with dims modifier.
llvmIR translation for num_teams with dims modifier is marked as NYI.

def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;

//===----------------------------------------------------------------------===//
// V6.2: Multidimensional `num_teams` clause with dims modifier
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// V6.2: Multidimensional `num_teams` clause with dims modifier
// V6.1: Multidimensional `num_teams` clause with dims modifier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

// If dims not specified but we have values, it's implicitly unidimensional
if (!dims.has_value() && values.size() != 1) {
return parser.emitError(parser.getCurrentLocation())
<< "expected 1 value without dims modifier, got " << values.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
<< "expected 1 value without dims modifier, got " << values.size();
<< "expected 1 value without dims modifier, but got " << values.size() << " values";

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Thanks

}];
}

def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be rather called modifier instead of clause? The clause still is num_threads, but the modifier is dims.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original design was to have a separate dims modifier(with dims and values args) class and then create num_teams and thread_limit clauses from it. But this leads to both clauses having the same argument names and when added to teams Op would create an issue.

So, now created just num_teams_multi_dim clause with arguments as num_teams_dims and num_teams_values.
Will remove the old num_teams clause and replace it with num_teams_multi_dim clause and move the name back to num_teams

@skc7
Copy link
Contributor Author

skc7 commented Dec 10, 2025

Updated PR to use dims modifier arguments in the original num_teams clause itself.
This is to support the old version and new version of num_teams using the same clause.

  • omp.teams num_teams(%lb : i32 to %ub : i32)
  • omp.teams num_teams(dims(3): %lb, %ub, %ub : i32)

Updated parser, printer and verifier for the clause.

Thanks for feedback @kparzysz

@skc7 skc7 marked this pull request as ready for review December 10, 2025 07:54
@llvmbot
Copy link
Member

llvmbot commented Dec 10, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Chaitanya (skc7)

Changes

This is WIP PR for support of openmp 6.1 feature num_teams with dims modifier.


Full diff: https://github.com/llvm/llvm-project/pull/169883.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+48-8)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+182-8)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+77-2)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+6)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8e43c4284d078..1b44873ea99b1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -974,22 +974,62 @@ class OpenMP_NumTeamsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+    Variadic<AnyInteger>:$num_teams_values,
     Optional<AnyInteger>:$num_teams_lower,
     Optional<AnyInteger>:$num_teams_upper
   );
 
   let optAssemblyFormat = [{
-    `num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to`
-                      $num_teams_upper `:` type($num_teams_upper) `)`
+    `num_teams` `(` custom<NumTeamsClause>(
+      $num_teams_dims, $num_teams_values, type($num_teams_values),
+      $num_teams_lower, type($num_teams_lower),
+      $num_teams_upper, type($num_teams_upper)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_teams_upper` and `num_teams_lower` arguments specify the
-    limit on the number of teams to be created. If only the upper bound is
-    specified, it acts as if the lower bound was set to the same value. It is
-    not allowed to set `num_teams_lower` if `num_teams_upper` is not specified.
-    They define a closed range, where both the lower and upper bounds are
-    included.
+    The `num_teams` clause specifies the bounds on the league space formed by the
+    construct on which it appears.
+
+    With dims modifier: (OpenMP 6.1 requirement)
+    - Uses `num_teams_dims` (dimension count) and `num_teams_values` (upper bounds list)
+    - Specifies upper bounds for each dimension (all must have same type)
+    - Format: `num_teams(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+    - Example: `num_teams(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_teams_upper` and optional `num_teams_lower`
+    - If lower bound not specified, it defaults to upper bound value
+    - Format: `num_teams(lower : type to upper : type)` or `num_teams(to upper : type)`
+    - Example: `num_teams(%lb : i32 to %ub : i32)` or `num_teams(to %ub : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasDimsModifier() {
+      return getNumTeamsDims().has_value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getNumDimensions() {
+      if (!hasDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumTeamsDims());
+    }
+
+    /// Returns all dimension values as an operand range
+    ::mlir::OperandRange getDimensionValues() {
+      return getNumTeamsValues();
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumDimensions()
+    ::mlir::Value getDimensionValue(unsigned index) {
+      assert(index < getDimensionValues().size() &&
+             "Dimension index out of bounds");
+      return getDimensionValues()[index];
+    }
   }];
 }
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0d6b2870c625a..6f56833b3b76a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2620,7 +2620,8 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+                 clauses.ifExpr, clauses.numTeamsDims, clauses.numTeamsValues,
+                 clauses.numTeamsLower, clauses.numTeamsUpper,
                  /*private_vars=*/{}, /*private_syms=*/nullptr,
                  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
                  clauses.reductionVars,
@@ -2642,14 +2643,57 @@ LogicalResult TeamsOp::verify() {
                      "in any OpenMP dialect operations");
 
   // Check for num_teams clause restrictions
-  if (auto numTeamsLowerBound = getNumTeamsLower()) {
-    auto numTeamsUpperBound = getNumTeamsUpper();
-    if (!numTeamsUpperBound)
-      return emitError("expected num_teams upper bound to be defined if the "
-                       "lower bound is defined");
-    if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
+  auto numTeamsDims = getNumTeamsDims();
+  auto numTeamsValues = getNumTeamsValues();
+  auto numTeamsLower = getNumTeamsLower();
+  auto numTeamsUpper = getNumTeamsUpper();
+
+  // Cannot use both dims modifier and unidimensional style
+  if (numTeamsDims.has_value() && (numTeamsLower || numTeamsUpper)) {
+    return emitError(
+        "num_teams with dims modifier cannot be used together with "
+        "lower/upper bounds (unidimensional style)");
+  }
+
+  // With dims modifier (multidimensional)
+  if (numTeamsDims.has_value()) {
+    if (numTeamsValues.empty()) {
+      return emitError(
+          "num_teams dims modifier requires values to be specified");
+    }
+
+    if (numTeamsValues.size() != static_cast<size_t>(*numTeamsDims)) {
+      return emitError("num_teams dims(")
+             << *numTeamsDims << ") specified but " << numTeamsValues.size()
+             << " values provided";
+    }
+
+    // All values must have the same type
+    if (!numTeamsValues.empty()) {
+      Type firstType = numTeamsValues.front().getType();
+      for (auto value : numTeamsValues) {
+        if (value.getType() != firstType) {
+          return emitError(
+              "num_teams dims modifier requires all values to have "
+              "the same type");
+        }
+      }
+    }
+  } else {
+    // Without dims modifier
+    if (!numTeamsValues.empty()) {
       return emitError(
-          "expected num_teams upper bound and lower bound to be the same type");
+          "num_teams values can only be specified with dims modifier");
+    }
+
+    if (numTeamsLower) {
+      if (!numTeamsUpper)
+        return emitError("expected num_teams upper bound to be defined if the "
+                         "lower bound is defined");
+      if (numTeamsLower.getType() != numTeamsUpper.getType())
+        return emitError("expected num_teams upper bound and lower bound to be "
+                         "the same type");
+    }
   }
 
   // Check for allocate clause restrictions
@@ -4453,6 +4497,136 @@ LogicalResult WorkdistributeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Helper: Parse dims modifier with values
+//===----------------------------------------------------------------------===//
+// Parses: dims(N): values : type (single type for all values)
+static ParseResult parseDimsModifierWithValues(
+    OpAsmParser &parser, IntegerAttr &dimsAttr,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    SmallVectorImpl<Type> &types) {
+  if (failed(parser.parseOptionalKeyword("dims"))) {
+    return failure();
+  }
+
+  // Parse (N): values : type
+  int64_t dimsValue;
+  if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
+      parser.parseRParen() || parser.parseColon()) {
+    return failure();
+  }
+
+  if (parser.parseOperandList(values) || parser.parseColon()) {
+    return failure();
+  }
+
+  // Parse single type (all values have same type)
+  Type valueType;
+  if (parser.parseType(valueType)) {
+    return failure();
+  }
+
+  // Fill types vector with same type for all values
+  types.assign(values.size(), valueType);
+
+  dimsAttr = parser.getBuilder().getI64IntegerAttr(dimsValue);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_teams clause with dims modifier
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumTeamsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                    SmallVectorImpl<Type> &types,
+                    std::optional<OpAsmParser::UnresolvedOperand> &lowerBound,
+                    Type &lowerBoundType,
+                    std::optional<OpAsmParser::UnresolvedOperand> &upperBound,
+                    Type &upperBoundType) {
+
+  // Format: num_teams(dims(N): values : type)
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+    return success();
+  }
+
+  // Format: num_teams(to upper : type)
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    OpAsmParser::UnresolvedOperand upperOperand;
+    if (parser.parseOperand(upperOperand) || parser.parseColon() ||
+        parser.parseType(upperBoundType)) {
+      return failure();
+    }
+    upperBound = upperOperand;
+    return success();
+  }
+
+  // Format: num_teams(lower : type to upper : type)
+  OpAsmParser::UnresolvedOperand lowerOperand;
+  if (parser.parseOperand(lowerOperand) || parser.parseColon() ||
+      parser.parseType(lowerBoundType)) {
+    return failure();
+  }
+
+  if (failed(parser.parseKeyword("to"))) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 'to' keyword in num_teams clause";
+  }
+
+  OpAsmParser::UnresolvedOperand upperOperand;
+  if (parser.parseOperand(upperOperand) || parser.parseColon() ||
+      parser.parseType(upperBoundType)) {
+    return failure();
+  }
+
+  lowerBound = lowerOperand;
+  upperBound = upperOperand;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Helper: Print dims modifier with values
+//===----------------------------------------------------------------------===//
+// Prints: dims(N): values : type (single type for all values)
+static void printDimsModifierWithValues(OpAsmPrinter &p, IntegerAttr dimsAttr,
+                                        OperandRange values, TypeRange types) {
+  if (dimsAttr) {
+    p << "dims(" << dimsAttr.getInt() << "): ";
+  }
+
+  p.printOperands(values);
+
+  // Print single type
+  p << " : ";
+  if (!types.empty()) {
+    p << types.front();
+  }
+}
+
+static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
+                                IntegerAttr dimsAttr, OperandRange values,
+                                TypeRange types, Value lowerBound,
+                                Type lowerBoundType, Value upperBound,
+                                Type upperBoundType) {
+  if (!values.empty()) {
+    // Multidimensional: dims(N): values : type
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  } else if (upperBound) {
+    if (lowerBound) {
+      // Both bounds: lower : type to upper : type
+      p.printOperand(lowerBound);
+      p << " : " << lowerBoundType << " to ";
+      p.printOperand(upperBound);
+      p << " : " << upperBoundType;
+    } else {
+      // Upper only: to upper : type
+      p << " to ";
+      p.printOperand(upperBound);
+      p << " : " << upperBoundType;
+    }
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index af24d969064ab..836cac9a53707 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
     // expected-error @below {{expected equal sizes for allocate and allocator variables}}
     "omp.teams" (%data_var) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
     omp.terminator
   }
   return
@@ -1451,7 +1451,82 @@ func.func @omp_teams_num_teams1(%lb : i32) {
     // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
     "omp.teams" (%lb) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_mismatch() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    // expected-error @below {{num_teams dims(3) specified but 2 values provided}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {num_teams_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_with_bounds() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    %lb = arith.constant 3 : i32
+    %ub = arith.constant 4 : i32
+    // expected-error @below {{num_teams with dims modifier cannot be used together with lower/upper bounds (unidimensional style)}}
+    "omp.teams" (%v0, %v1, %lb, %ub) ({
+      omp.terminator
+    }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_values_without_dims() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    // expected-error @below {{num_teams values can only be specified with dims modifier}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_no_values() {
+  omp.target {
+    // expected-error @below {{num_teams dims modifier requires values to be specified}}
+    "omp.teams" () ({
+      omp.terminator
+    }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_type_mismatch() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i64
+    // expected-error @below {{num_teams dims modifier requires all values to have the same type}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i64) -> ()
     omp.terminator
   }
   return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index ac29e20907b55..3633a4be1eb62 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1108,6 +1108,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
+  // CHECK: omp.teams num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+  omp.teams num_teams(dims(3): %lb, %ub, %ub : i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
   // Test if.
   // CHECK: omp.teams if(%{{.+}})
   omp.teams if(%if_cond) {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants